{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import glob \n",
    "import torch\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.functional as F\n",
    "from torch import optim\n",
    "from torch.utils.data import Dataset,DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torchaudio\n",
    "import librosa\n",
    "from torchaudio import transforms\n",
    "from model import WaveNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import IPython.display\n",
    "import numpy as np\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "class VCTK(Dataset):\n",
    "    def __init__(self,path='./VCTK/',speaker='p225',transform=None,sr=16000,top_db=10):\n",
    "        self.wav_list = glob.glob(path + speaker +'/*.wav')\n",
    "        self.wav_ids = sorted([f.split('/')[-1] for f in glob.glob(path+'*')])\n",
    "        self.transform = transform\n",
    "        self.sr = sr\n",
    "        self.top_db = top_db\n",
    "        \n",
    "    def __getitem__(self, index):\n",
    "        f = self.wav_list[index]\n",
    "        audio,_ = librosa.load(f,sr=self.sr,mono=True)\n",
    "        audio,_ = librosa.effects.trim(audio, top_db=self.top_db, frame_length=2048)\n",
    "        audio = np.clip(audio,-1,1)\n",
    "        wav_tensor = torch.from_numpy(audio).unsqueeze(1)\n",
    "        wav_id = f.split('/')[3]\n",
    "        if self.transform is not None:\n",
    "            wav_tensor = self.transform(wav_tensor)\n",
    "        \n",
    "        return wav_tensor\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.wav_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Transform pipeline handed to the dataset: MuLawEncoding first, then LC2CL.\n",
    "t = transforms.Compose(\n",
    "    [transforms.MuLawEncoding(), transforms.LC2CL()]\n",
    ")\n",
    "\n",
    "def collate_fn_(batch_data, max_len=40000):\n",
    "    audio = batch_data[0]\n",
    "    audio_len = audio.size(1)\n",
    "    if audio_len > max_len:\n",
    "        idx = random.randint(0,audio_len - max_len)\n",
    "        return audio[:,idx:idx+max_len]\n",
    "    else:\n",
    "        return audio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Single-speaker dataset; `t` is applied to every loaded utterance.\n",
    "vctk = VCTK(\n",
    "    speaker='p225',\n",
    "    transform=t,\n",
    "    sr=16000\n",
    ")\n",
    "# One utterance per batch; `collate_fn_` unwraps (and possibly crops) that single item.\n",
    "training_data = DataLoader(\n",
    "    vctk,\n",
    "    batch_size=1,\n",
    "    shuffle=True,\n",
    "    collate_fn=collate_fn_\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Build the network and move it to the GPU (a CUDA device is required).\n",
    "model = WaveNet()\n",
    "model = model.cuda()\n",
    "# Adam over all model parameters.\n",
    "train_step = optim.Adam(\n",
    "    model.parameters(),\n",
    "    lr=2e-3,\n",
    "    eps=1e-4\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Learning-rate schedule attached to `train_step`: gamma=0.5 at milestones 50, 150 and 250.\n",
    "scheduler = optim.lr_scheduler.MultiStepLR(\n",
    "    train_step,\n",
    "    milestones=[50, 150, 250],\n",
    "    gamma=0.5\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "for epoch in range(1000):\n",
    "    # Per-batch losses of the current epoch.\n",
    "    loss_ = []\n",
    "    scheduler.step()\n",
    "    for data in training_data:\n",
    "        data = Variable(data).cuda()\n",
    "\n",
    "        # The model sees everything except the last sample.\n",
    "        inputs = data[:, :-1]\n",
    "        logits = model(inputs)\n",
    "\n",
    "        # Targets are the trailing samples of `data`, one per output step of the model.\n",
    "        n_out = logits.size(2)\n",
    "        targets = data[:, -n_out:]\n",
    "\n",
    "        # Flatten to (N, 256) scores and (N,) class indices for cross entropy.\n",
    "        flat_logits = logits.transpose(1, 2).contiguous().view(-1, 256)\n",
    "        flat_targets = targets.view(-1)\n",
    "        loss = F.cross_entropy(flat_logits, flat_targets)\n",
    "\n",
    "        train_step.zero_grad()\n",
    "        loss.backward()\n",
    "        train_step.step()\n",
    "        loss_.append(loss.data[0])\n",
    "\n",
    "    # Epoch index followed by the mean batch loss of that epoch.\n",
    "    print epoch,np.mean(loss_)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
